/*
 Copyright 2013--Present JMM_PROGNAME
 
 This file is distributed under the terms of the JMM_PROGNAME License.
 
 You should have received a copy of the JMM_PROGNAME License.
 If not, see <JMM_PROGNAME WEBSITE>.
*/
// CREATED    : 7/15/2015
// LAST UPDATE: 10/1/2015

#include "statsxx/optimization/line_search_methods/ConjugateGradient.hpp"

// STL
#include <cmath>      // std::sqrt()
#include <cstdlib>    // std::abs()
#include <functional> // std::function<>
#include <limits>     // std::numeric_limits<>
#include <tuple>      // std::tuple<>, std::make_tuple()

// jScience
#include "jScience/physics/consts.hpp" // golden_ratio2


inline ConjugateGradient::ConjugateGradient()
{
    // maximum number of outer (conjugate-direction) iterations performed by minimize()
    this->iter_max = 100;

    // conjugate-direction update rule: "PR" selects Polak-Ribiere;
    // minimize() treats any other value as Fletcher-Reeves
    this->method = "PR";

    // convergence tolerances
    this->tol  = 1.0e-3; // angle test: (g.p)^2 / (|g|^2 |p|^2) between new gradient and last direction
    this->gtol = 1.0e-4; // gradient tolerance
    this->ftol = 1.0e-4; // relative change in the function value between iterations
}

inline ConjugateGradient::~ConjugateGradient() {} // no explicit cleanup; (no stray ';' after a function body)


//========================================================================
//========================================================================
//
// NAME: std::tuple<int,
//                  Vector<double>,
//                  double> ConjugateGradient::minimize(std::function<double(const Vector<double> &)> f,
//                                                      std::function<Vector<double>(const Vector<double> &)> df,
//                                                      Vector<double> x)
//
// DESC: Minimize the function f, using the method of conjugate gradients.
//
// INPUT:
//     std::function<double(const Vector<double> &)> f : objective function to minimize
//     std::function<Vector<double>(const Vector<double> &)> df : gradient of f
//     Vector<double> x : starting point (taken by value; the local copy is iterated)
//
// OUTPUT:
//     int info : 0 if a convergence criterion was met (or the starting gradient
//                is exactly zero); 1 if iter_max iterations ran without converging
//     Vector<double> x : final point
//     double fx : function value at the final point (as reported by line_minimize(),
//                 or f(x) if no line minimization was performed)
//
// NOTES:
//     ! This implementation is based on ...
//          R. Fletcher and C. M. Reeves "Function minimization by conjugate gradients", The Computer Journal 7, 149-154 (1964)
//     ! ... with the modification by Polak & Ribiere (method == "PR", the default);
//       any other value of method selects the Fletcher-Reeves update.
//     ! The status returned by each line minimization is ignored: an inexact
//       line minimization is acceptable for conjugate gradients.
//     ! NOTE(review): the Polak-Ribiere beta is not clamped at zero (no "PR+"
//       automatic restart); consider beta = max(beta, 0) if descent is lost.
//
//========================================================================
//========================================================================
inline std::tuple<
                  int,
                  Vector<double>,
                  double
                  > ConjugateGradient::minimize(
                                                      std::function<double(const Vector<double> &)> f,
                                                      std::function<Vector<double>(const Vector<double> &)> df,
                                                      Vector<double> x
                                                      )
{
    // info stays 1 unless a convergence criterion is met
    int info = 1;
    
    // initial function value (fx), gradient (g), and squared gradient norm (gg)
    double fx = f(x);
    Vector<double> g = df(x);
    double gg = dot_product(g, g);
    
    // an exactly-zero gradient means x is already a stationary point
    // (exact comparison is intentional: it guards the division by g.g below)
    if(gg == 0.0)
    {
        return std::make_tuple(0, x, fx);
    }
    
    // the first conjugate direction is the steepest-descent direction
    Vector<double> p = -g;
    
    for(int iter = 0; iter < this->iter_max; ++iter)
    {
        // fx_i, g_i and g_i.g_i are needed after the line minimization;
        // gigi > 0 here because gg == 0 always exits before the next iteration
        const double fxi = fx;
        Vector<double> gi = g;
        const double gigi = gg;
        
        // line minimize f(x) along the conjugate direction p;
        // the line minimizer's status is deliberately discarded (convergence not required)
        std::tie(std::ignore, x, fx) = this->line_minimize(f, x, p);
        
        // compute the new gradient and its squared norm
        g = df(x);
        gg = dot_product(g, g);
        
        // check for convergence; an exactly-zero gradient is an exact stationary
        // point and would otherwise make gigi == 0 (beta = NaN) on the next iteration
        if((gg == 0.0) || this->converged(fx, fxi, g, p))
        {
            info = 0;
            
            break;
        }
        
        // calculate a new conjugate direction
        double beta;
        
        if(this->method == "PR")
        {
            // Polak-Ribiere: beta = (g_{i+1} - g_i).g_{i+1} / (g_i.g_i)
            beta = dot_product((g - gi), g)/gigi;
        }
        else // Fletcher-Reeves
        {
            // beta = g_{i+1}.g_{i+1} / (g_i.g_i)
            beta = gg/gigi;
        }
        
        p = -g + beta*p;
    }
    
    return std::make_tuple(info, x, fx);
}

// Decide whether the conjugate-gradient iteration has converged.
//
//     f  : function value after the current line minimization (f_{i+1})
//     fi : function value before it (f_i)
//     g  : gradient at the new point (g_{i+1})
//     pi : search direction used for the line minimization (p_i)
//
// Returns true if any of the three criteria below holds.
inline bool ConjugateGradient::converged(const double f, const double fi, const Vector<double> &g, const Vector<double> &pi)
{
    // convergence criterion: relative change in the function f()
    // ('<=' rather than '<' so that f == fi == 0 is also accepted; with '<'
    //  the zero right-hand side could never be beaten)
    if((2.0*std::abs(f - fi)) <= (this->ftol*(std::abs(f) + std::abs(fi))))
    {
        return true;
    }
    
    // convergence criterion: vanishing gradient, ||g|| <= gtol
    // (the previous test 'gg < 0.0' could never be true since g.g >= 0, leaving gtol unused)
    // NOTE(review): gtol is treated as an absolute tolerance on the Euclidean norm — confirm.
    // '<=' also catches an exactly-zero gradient when gtol == 0, so gg > 0 below.
    const double gg = dot_product(g, g);
    
    if(std::sqrt(gg) <= this->gtol)
    {
        return true;
    }
    
    // convergence criterion: check how far the angle between g_{i+1} and p_i differs from the theoretical 90 deg
    // (the squared cosine of that angle, (g.p)^2 / (|g|^2 |p|^2), should be near zero)
    const double gp = dot_product(g, pi);
    if( (gp*gp/(gg*dot_product(pi, pi))) < this->tol )
    {
        return true;
    }
    
    return false;
}

